The following code loads a dataset of mushroom properties (originally from http://archive.ics.uci.edu/ml/machine-learning-databases/mushroom/agaricus-lepiota.names) and fits gradient boosted trees.
library(xgboost)
library(DiagrammeR)
data(agaricus.train, package='xgboost')
data(agaricus.test, package='xgboost')
set.seed(991)
dim(agaricus.train$data)
## [1] 6513 126
bst <- xgboost(data = agaricus.train$data, label = agaricus.train$label,
max_depth = 2, eta = 1, nthread = 2, nrounds = 2,
objective = "binary:logistic")
## [1] train-logloss:0.233376
## [2] train-logloss:0.136658
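The values printed above are the training log loss after each boosting round. As a hedged sketch (toy values only, not the mushroom data), binary log loss can be computed directly:

```r
# Binary log loss: -mean(y * log(p) + (1 - y) * log(1 - p))
logloss <- function(y, p) -mean(y * log(p) + (1 - y) * log(1 - p))

# Toy example: labels y and predicted probabilities p (illustrative values)
logloss(c(1, 0, 1), c(0.9, 0.2, 0.8))  # about 0.1839
```

Lower values mean the predicted probabilities are closer to the true labels, so the drop from round 1 to round 2 shows the second tree is improving the fit.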
pred <- predict(bst, agaricus.test$data)
# confusion matrices are useful!
table(Actual = agaricus.test$label, Predicted = pred > 0.5)
## Predicted
## Actual FALSE TRUE
## 0 813 22
## 1 13 763
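From the counts in the confusion matrix above, the test misclassification rate can be computed directly. A sketch using the reported counts (so it does not need the fitted model):

```r
# Confusion matrix counts as reported above
conf <- matrix(c(813, 13, 22, 763), nrow = 2,
               dimnames = list(Actual = c("0", "1"),
                               Predicted = c("FALSE", "TRUE")))

# Misclassification rate: off-diagonal counts over the total
err <- (conf["0", "TRUE"] + conf["1", "FALSE"]) / sum(conf)
round(err, 4)  # about 0.0217
```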
What do max_depth = 2, eta = 1 and nrounds = 2 do? Use xgb.plot.tree to draw the tree (it appears in your browser; you need to export/save it from there).
Let's first look at our first observation to better understand xgboost.
agaricus.test$data[1,]
## cap-shape=bell cap-shape=conical
## 1 0
## cap-shape=convex cap-shape=flat
## 0 0
## cap-shape=knobbed cap-shape=sunken
## 0 0
## cap-surface=fibrous cap-surface=grooves
## 0 0
## cap-surface=scaly cap-surface=smooth
## 1 0
## cap-color=brown cap-color=buff
## 0 0
## cap-color=cinnamon cap-color=gray
## 0 0
## cap-color=green cap-color=pink
## 0 0
## cap-color=purple cap-color=red
## 0 0
## cap-color=white cap-color=yellow
## 1 0
## bruises?=bruises bruises?=no
## 1 0
## odor=almond odor=anise
## 0 1
## odor=creosote odor=fishy
## 0 0
## odor=foul odor=musty
## 0 0
## odor=none odor=pungent
## 0 0
## odor=spicy gill-attachment=attached
## 0 0
## gill-attachment=descending gill-attachment=free
## 0 1
## gill-attachment=notched gill-spacing=close
## 0 1
## gill-spacing=crowded gill-spacing=distant
## 0 0
## gill-size=broad gill-size=narrow
## 1 0
## gill-color=black gill-color=brown
## 0 1
## gill-color=buff gill-color=chocolate
## 0 0
## gill-color=gray gill-color=green
## 0 0
## gill-color=orange gill-color=pink
## 0 0
## gill-color=purple gill-color=red
## 0 0
## gill-color=white gill-color=yellow
## 0 0
## stalk-shape=enlarging stalk-shape=tapering
## 1 0
## stalk-root=bulbous stalk-root=club
## 0 1
## stalk-root=cup stalk-root=equal
## 0 0
## stalk-root=rhizomorphs stalk-root=rooted
## 0 0
## stalk-root=missing stalk-surface-above-ring=fibrous
## 0 0
## stalk-surface-above-ring=scaly stalk-surface-above-ring=silky
## 0 0
## stalk-surface-above-ring=smooth stalk-surface-below-ring=fibrous
## 1 0
## stalk-surface-below-ring=scaly stalk-surface-below-ring=silky
## 0 0
## stalk-surface-below-ring=smooth stalk-color-above-ring=brown
## 1 0
## stalk-color-above-ring=buff stalk-color-above-ring=cinnamon
## 0 0
## stalk-color-above-ring=gray stalk-color-above-ring=orange
## 0 0
## stalk-color-above-ring=pink stalk-color-above-ring=red
## 0 0
## stalk-color-above-ring=white stalk-color-above-ring=yellow
## 1 0
## stalk-color-below-ring=brown stalk-color-below-ring=buff
## 0 0
## stalk-color-below-ring=cinnamon stalk-color-below-ring=gray
## 0 0
## stalk-color-below-ring=orange stalk-color-below-ring=pink
## 0 0
## stalk-color-below-ring=red stalk-color-below-ring=white
## 0 1
## stalk-color-below-ring=yellow veil-type=partial
## 0 1
## veil-type=universal veil-color=brown
## 0 0
## veil-color=orange veil-color=white
## 0 1
## veil-color=yellow ring-number=none
## 0 0
## ring-number=one ring-number=two
## 1 0
## ring-type=cobwebby ring-type=evanescent
## 0 0
## ring-type=flaring ring-type=large
## 0 0
## ring-type=none ring-type=pendant
## 0 1
## ring-type=sheathing ring-type=zone
## 0 0
## spore-print-color=black spore-print-color=brown
## 0 1
## spore-print-color=buff spore-print-color=chocolate
## 0 0
## spore-print-color=green spore-print-color=orange
## 0 0
## spore-print-color=purple spore-print-color=white
## 0 0
## spore-print-color=yellow population=abundant
## 0 0
## population=clustered population=numerous
## 0 0
## population=scattered population=several
## 1 0
## population=solitary habitat=grasses
## 0 0
## habitat=leaves habitat=meadows
## 0 1
## habitat=paths habitat=urban
## 0 0
## habitat=waste habitat=woods
## 0 0
Demonstration of how xgboost scores the first observation:
knitr::include_graphics("data/boosting_working.png")
This matches the model's prediction for the first observation:
pred[1]
## [1] 0.2858302
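The predicted value is a probability. With objective = "binary:logistic", xgboost sums the leaf values of all trees to get a raw log-odds margin and then applies the logistic transform. A minimal sketch of that final step, recovering the margin from the probability reported above:

```r
p      <- 0.2858302              # probability reported by predict() above
margin <- qlogis(p)              # log-odds: log(p / (1 - p))
prob   <- 1 / (1 + exp(-margin)) # logistic transform, as xgboost applies it
round(prob, 7)  # 0.2858302 (round trip recovers the probability)
```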
Plotting Tree
xgb.plot.tree(model = bst)
Now fit a model with max_depth = 2 and eta = 1 but with nrounds chosen to minimise cross-validation loss. Use xgb.plot.tree to plot it. Comment on the relative accuracy and complexity of the two models.
xgb_best <- xgb.cv(data = agaricus.train$data, label = agaricus.train$label,
max_depth = 2, eta = 1, nthread = 2, nrounds = 30, nfold = 5,
objective = "binary:logistic", metrics = "error")
## [1] train-error:0.046522+0.000911 test-error:0.046522+0.003644
## [2] train-error:0.022263+0.000676 test-error:0.022263+0.002704
## [3] train-error:0.007063+0.000255 test-error:0.007063+0.001019
## [4] train-error:0.015200+0.000477 test-error:0.015201+0.001909
## [5] train-error:0.007063+0.000255 test-error:0.007063+0.001019
## [6] train-error:0.001689+0.000989 test-error:0.002303+0.002003
## [7] train-error:0.001228+0.000153 test-error:0.001228+0.000614
## [8] train-error:0.001228+0.000153 test-error:0.001228+0.000614
## [9] train-error:0.001152+0.000172 test-error:0.001228+0.000614
## [10] train-error:0.001152+0.000172 test-error:0.001228+0.000614
## [11] train-error:0.000960+0.000500 test-error:0.001075+0.000783
## [12] train-error:0.000422+0.000521 test-error:0.000767+0.000971
## [13] train-error:0.000000+0.000000 test-error:0.000000+0.000000
## [14] train-error:0.000000+0.000000 test-error:0.000000+0.000000
## [15] train-error:0.000000+0.000000 test-error:0.000000+0.000000
## [16] train-error:0.000000+0.000000 test-error:0.000000+0.000000
## [17] train-error:0.000000+0.000000 test-error:0.000000+0.000000
## [18] train-error:0.000000+0.000000 test-error:0.000000+0.000000
## [19] train-error:0.000000+0.000000 test-error:0.000000+0.000000
## [20] train-error:0.000000+0.000000 test-error:0.000000+0.000000
## [21] train-error:0.000000+0.000000 test-error:0.000000+0.000000
## [22] train-error:0.000000+0.000000 test-error:0.000000+0.000000
## [23] train-error:0.000000+0.000000 test-error:0.000000+0.000000
## [24] train-error:0.000000+0.000000 test-error:0.000000+0.000000
## [25] train-error:0.000000+0.000000 test-error:0.000000+0.000000
## [26] train-error:0.000000+0.000000 test-error:0.000000+0.000000
## [27] train-error:0.000000+0.000000 test-error:0.000000+0.000000
## [28] train-error:0.000000+0.000000 test-error:0.000000+0.000000
## [29] train-error:0.000000+0.000000 test-error:0.000000+0.000000
## [30] train-error:0.000000+0.000000 test-error:0.000000+0.000000
Note: the training loss is "logloss"; we use the metric "error" to measure the prediction error for binary classification.
The cross-validated test error first reaches zero at round 13, but it is already very small (about 0.007) by round 3, so nrounds = 3 gives a much simpler model at little cost in accuracy.
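In code, the best number of rounds can be read off the evaluation log that xgb.cv returns. A hedged sketch, mimicking that table with the test errors reported above for the first five rounds (the column name test_error_mean is what xgb.cv uses for this metric):

```r
# Stand-in for xgb_best$evaluation_log, built from the printed CV output
eval_log <- data.frame(iter = 1:5,
                       test_error_mean = c(0.046522, 0.022263, 0.007063,
                                           0.015201, 0.007063))

# which.min picks the earliest round in case of ties
best_round <- eval_log$iter[which.min(eval_log$test_error_mean)]
best_round  # 3
```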
xgb.cv(data = agaricus.train$data, label = agaricus.train$label,
max_depth = 2, eta = 1, nthread = 2, nrounds = 3, nfold = 5,
objective = "binary:logistic", metrics = "error")
## [1] train-error:0.046522+0.001388 test-error:0.046523+0.005555
## [2] train-error:0.022263+0.001196 test-error:0.022264+0.004784
## [3] train-error:0.007063+0.000561 test-error:0.007062+0.002245
bst2 = xgboost(data = agaricus.train$data, label = agaricus.train$label,
max_depth = 2, eta = 1, nthread = 2, nrounds = 3,
objective = "binary:logistic")
## [1] train-logloss:0.233376
## [2] train-logloss:0.136658
## [3] train-logloss:0.082531
xgb.plot.tree(model = bst2)
Now try varying eta to reduce cross-validation loss. (Think about a strategy for choosing values of eta to try, but don't try more than five or so different ones.)
for (i in c(0.1, 0.25, 0.5, 0.75, 1.0)) {
print(i)
xgb.cv(data = agaricus.train$data, label = agaricus.train$label,
max_depth = 2, eta = i, nthread = 2, nrounds = 10, nfold = 5,
objective = "binary:logistic", metrics = "error")
}
## [1] 0.1
## [1] train-error:0.046522+0.001003 test-error:0.046521+0.004010
## [2] train-error:0.042569+0.001180 test-error:0.042683+0.005848
## [3] train-error:0.046522+0.001003 test-error:0.046521+0.004010
## [4] train-error:0.041609+0.000971 test-error:0.041608+0.003886
## [5] train-error:0.041609+0.000971 test-error:0.041608+0.003886
## [6] train-error:0.041609+0.000971 test-error:0.041608+0.003886
## [7] train-error:0.037694+0.007680 test-error:0.038997+0.007148
## [8] train-error:0.030554+0.008559 test-error:0.031014+0.011339
## [9] train-error:0.041609+0.000971 test-error:0.041608+0.003886
## [10] train-error:0.023338+0.000782 test-error:0.023338+0.003130
## [1] 0.25
## [1] train-error:0.046522+0.001022 test-error:0.046523+0.004091
## [2] train-error:0.046522+0.001022 test-error:0.046523+0.004091
## [3] train-error:0.023338+0.000415 test-error:0.023338+0.001661
## [4] train-error:0.041609+0.000793 test-error:0.041610+0.003172
## [5] train-error:0.009443+0.006902 test-error:0.009519+0.007306
## [6] train-error:0.015200+0.004545 test-error:0.015354+0.003435
## [7] train-error:0.013473+0.002327 test-error:0.013818+0.001813
## [8] train-error:0.018693+0.002386 test-error:0.019653+0.003904
## [9] train-error:0.019883+0.001061 test-error:0.020728+0.003730
## [10] train-error:0.020881+0.001198 test-error:0.020881+0.003481
## [1] 0.5
## [1] train-error:0.046522+0.001308 test-error:0.046521+0.005233
## [2] train-error:0.045179+0.001532 test-error:0.046061+0.005590
## [3] train-error:0.023530+0.001287 test-error:0.024874+0.003315
## [4] train-error:0.028904+0.003929 test-error:0.029939+0.007388
## [5] train-error:0.013665+0.003770 test-error:0.013359+0.003783
## [6] train-error:0.016582+0.001345 test-error:0.014892+0.005835
## [7] train-error:0.003032+0.002236 test-error:0.002918+0.001489
## [8] train-error:0.008560+0.003179 test-error:0.007522+0.004214
## [9] train-error:0.001996+0.000260 test-error:0.001996+0.001041
## [10] train-error:0.001689+0.000444 test-error:0.001689+0.001128
## [1] 0.75
## [1] train-error:0.046522+0.000978 test-error:0.046523+0.003913
## [2] train-error:0.039230+0.010643 test-error:0.039611+0.010708
## [3] train-error:0.026447+0.005216 test-error:0.026869+0.003434
## [4] train-error:0.016122+0.007104 test-error:0.016735+0.006854
## [5] train-error:0.010095+0.002944 test-error:0.009674+0.003353
## [6] train-error:0.002149+0.000818 test-error:0.001689+0.000895
## [7] train-error:0.002495+0.000993 test-error:0.002149+0.001128
## [8] train-error:0.001689+0.000352 test-error:0.001689+0.001128
## [9] train-error:0.001651+0.000464 test-error:0.001842+0.000921
## [10] train-error:0.001420+0.000773 test-error:0.001535+0.001284
## [1] 1
## [1] train-error:0.051014+0.009867 test-error:0.054350+0.013054
## [2] train-error:0.021188+0.002272 test-error:0.021649+0.005031
## [3] train-error:0.009788+0.005480 test-error:0.010593+0.007106
## [4] train-error:0.014087+0.002638 test-error:0.014585+0.002950
## [5] train-error:0.005950+0.002231 test-error:0.006449+0.001855
## [6] train-error:0.001305+0.000144 test-error:0.001689+0.001128
## [7] train-error:0.001228+0.000094 test-error:0.001228+0.000376
## [8] train-error:0.001305+0.000144 test-error:0.001689+0.001128
## [9] train-error:0.001228+0.000094 test-error:0.001228+0.000376
## [10] train-error:0.000461+0.000564 test-error:0.000614+0.000753
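Note that the loop above prints each training log but discards the xgb.cv results. A hedged sketch of how the minimum CV test error per eta could be collected and compared, using the smallest test errors visible in the logs above as stand-in values:

```r
# Smallest test-error reached within 10 rounds for each eta (from the logs)
min_err <- c("0.1" = 0.023338, "0.25" = 0.009519, "0.5" = 0.001689,
             "0.75" = 0.001535, "1" = 0.000614)

best_eta <- names(which.min(min_err))
best_eta  # "1": the largest learning rate reaches the lowest CV error here
```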
We look at the smallest test errors (the dips) for each learning rate to judge how many boosting rounds are needed. Higher learning rates reach low error sooner: with eta = 1 the test error falls below 0.002 by round 6, while with eta = 0.1 it has not dropped below 0.02 even after 10 rounds. Unsurprisingly, the number of rounds needed grows as the learning rate decreases.
dimnames(agaricus.train$data)[[2]] # list the feature names; then just draw the diagram
knitr::include_graphics("data/boosting_new_working.jpg")
After some research, mushrooms A and B are actually poisonous, yet both are predicted as non-poisonous, which is incorrect. C is predicted as non-poisonous and really is non-poisonous, which is correct. This is not a useful model here; it was probably trained on species in America as opposed to New Zealand.